{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ssms\n",
    "import lanfactory\n",
    "import os\n",
    "import numpy as np\n",
    "from copy import deepcopy\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook-wide settings\n",
    "DEVICE = \"cpu\"  # picks train_config[DEVICE + \"_batch_size\"] for the dataloaders\n",
    "RUN_SIMS = False  # set True to (re)generate simulated training data\n",
    "MODEL = \"ddm\"  # ssms model name used throughout the notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from a private copy of the ssms LAN (MLP) data-generator defaults\n",
    "# and override the settings this notebook needs.\n",
    "generator_config = deepcopy(ssms.config.data_generator_config[\"lan\"])\n",
    "generator_config.update(\n",
    "    {\n",
    "        \"model\": MODEL,  # generative model to simulate\n",
    "        \"n_parameter_sets\": 256,  # number of parameter sets to simulate\n",
    "        \"n_samples\": 2000,  # samples per simulation run\n",
    "        \"output_folder\": \"data/lan_mlp/\",  # folder for generated data\n",
    "    }\n",
    ")\n",
    "\n",
    "# Model config dict (parameter names, bounds, simulator, ...)\n",
    "model_config = ssms.config.model_config[MODEL]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'name': 'ddm',\n",
       " 'params': ['v', 'a', 'z', 't'],\n",
       " 'param_bounds': [[-3.0, 0.3, 0.1, 0.0], [3.0, 2.5, 0.9, 2.0]],\n",
       " 'boundary_name': 'constant',\n",
       " 'boundary': <function ssms.basic_simulators.boundary_functions.constant(t=0)>,\n",
       " 'boundary_params': [],\n",
       " 'n_params': 4,\n",
       " 'default_params': [0.0, 1.0, 0.5, 0.001],\n",
       " 'nchoices': 2,\n",
       " 'simulator': <cyfunction ddm_flexbound at 0x11aa37100>}"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the model config (params, bounds, defaults, simulator)\n",
    "model_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write simulated data to ../data/lan_mlp/{model}/, the same folder the\n",
    "# loading cells below read from (only used when RUN_SIMS is True).\n",
    "generator_config[\"output_folder\"] = \"../data/lan_mlp/\" + generator_config[\"model\"] + \"/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally (re)generate simulated training data: each iteration builds a\n",
    "# fresh generator and calls generate_data_training_uniform(save=True)\n",
    "# (presumably writing into generator_config[\"output_folder\"]).\n",
    "if RUN_SIMS:\n",
    "    n_datafiles = 20\n",
    "    for i in range(n_datafiles):\n",
    "        print(\"Datafile: \", i)\n",
    "        my_dataset_generator = ssms.dataset_generators.lan_mlp.data_generator(\n",
    "            model_config=model_config,\n",
    "            generator_config=generator_config,\n",
    "        )\n",
    "        training_data = my_dataset_generator.generate_data_training_uniform(\n",
    "            save=True\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Load one simulated training file for inspection. Sorting makes the\n",
    "# choice of file deterministic (os.listdir order is arbitrary).\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load trusted files.\n",
    "folder_ = \"../data/lan_mlp/\" + MODEL + \"/\"\n",
    "files_ = sorted(folder_ + file_ for file_ in os.listdir(folder_))\n",
    "\n",
    "with open(files_[0], \"rb\") as data_file:\n",
    "    my_data = pickle.load(data_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network config: \n",
      "{'layer_sizes': [100, 100, 100, 1], 'activations': ['tanh', 'tanh', 'tanh', 'linear'], 'train_output_type': 'logprob'}\n",
      "Train config: \n",
      "{'cpu_batch_size': 128, 'gpu_batch_size': 256, 'n_epochs': 5, 'optimizer': 'adam', 'learning_rate': 0.002, 'lr_scheduler': 'reduce_on_plateau', 'lr_scheduler_params': {}, 'weight_decay': 0.0, 'loss': 'huber', 'save_history': True}\n"
     ]
    }
   ],
   "source": [
    "# MLP architecture: three hidden tanh layers of width 100, linear output.\n",
    "network_config = deepcopy(lanfactory.config.network_configs.network_config_mlp)\n",
    "network_config.update(\n",
    "    layer_sizes=[100, 100, 100, 1],\n",
    "    activations=[\"tanh\", \"tanh\", \"tanh\", \"linear\"],\n",
    ")\n",
    "print(\"Network config: \")\n",
    "print(network_config)\n",
    "\n",
    "# Default MLP training config (CPN-specific overrides follow).\n",
    "train_config = deepcopy(lanfactory.config.network_configs.train_config_mlp)\n",
    "print(\"Train config: \")\n",
    "print(train_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CPN setup: the network emits logits, trained with the \"bcelogit\" loss\n",
    "# (presumably binary cross-entropy on logits).\n",
    "network_config[\"train_output_type\"] = \"logits\"\n",
    "\n",
    "train_config.update(\n",
    "    loss=\"bcelogit\",\n",
    "    cpu_batch_size=1024,\n",
    "    gpu_batch_size=1024,\n",
    "    n_epochs=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_ = \"../data/lan_mlp/\" + MODEL + \"/\"\n",
    "# Sort so the train/validation split is reproducible (os.listdir order is arbitrary).\n",
    "file_list_ = sorted(folder_ + file_ for file_ in os.listdir(folder_))\n",
    "\n",
    "# Hold out the last file for validation so the validation loss is measured on\n",
    "# data the network was not trained on; with a single file, fall back to\n",
    "# validating on the training data.\n",
    "if len(file_list_) > 1:\n",
    "    train_files_, valid_files_ = file_list_[:-1], file_list_[-1:]\n",
    "else:\n",
    "    train_files_ = valid_files_ = file_list_\n",
    "\n",
    "# Batch size for the configured device (DEVICE is set in the settings cell).\n",
    "batch_size_ = train_config[DEVICE + \"_batch_size\"]\n",
    "\n",
    "\n",
    "def make_cpn_dataloader(file_ids):\n",
    "    \"\"\"Return (dataset, dataloader) over the CPN features/labels stored in `file_ids`.\"\"\"\n",
    "    dataset = lanfactory.trainers.DatasetTorch(\n",
    "        file_ids=file_ids,\n",
    "        batch_size=batch_size_,\n",
    "        label_lower_bound=np.log(1e-10),\n",
    "        features_key=\"cpn_data\",\n",
    "        label_key=\"cpn_labels\",\n",
    "        out_framework=\"jax\",\n",
    "    )\n",
    "    # batch_size=None disables DataLoader auto-batching; DatasetTorch receives\n",
    "    # batch_size itself and presumably yields whole batches.\n",
    "    dataloader = torch.utils.data.DataLoader(\n",
    "        dataset, shuffle=True, batch_size=None, num_workers=1, pin_memory=True\n",
    "    )\n",
    "    return dataset, dataloader\n",
    "\n",
    "\n",
    "jax_training_dataset, jax_training_dataloader = make_cpn_dataloader(train_files_)\n",
    "jax_validation_dataset, jax_validation_dataloader = make_cpn_dataloader(valid_files_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.0230,  0.8928,  0.2771,  1.8274],\n",
      "        [ 0.0773,  1.1757,  0.1404,  0.9382],\n",
      "        [-0.8907,  2.2692,  0.7324,  1.0069],\n",
      "        ...,\n",
      "        [-1.3451,  1.1634,  0.6648,  1.9732],\n",
      "        [ 2.5755,  0.7371,  0.6242,  0.0403],\n",
      "        [-1.3828,  1.4075,  0.5868,  1.0250]])\n",
      "tensor([[0.7155],\n",
      "        [0.8210],\n",
      "        [0.8890],\n",
      "        ...,\n",
      "        [0.8740],\n",
      "        [0.0070],\n",
      "        [0.9635]])\n",
      "tensor([[-2.1197,  1.4374,  0.8743,  1.9186],\n",
      "        [-1.6401,  2.4682,  0.8981,  0.8245],\n",
      "        [-2.2833,  0.9635,  0.2890,  1.5146],\n",
      "        ...,\n",
      "        [-2.2025,  0.7549,  0.2196,  0.9994],\n",
      "        [-2.9090,  1.2080,  0.8279,  0.4686],\n",
      "        [-1.4978,  0.4554,  0.1142,  1.9158]])\n",
      "tensor([[0.8000],\n",
      "        [0.8315],\n",
      "        [0.9985],\n",
      "        ...,\n",
      "        [0.9965],\n",
      "        [0.9270],\n",
      "        [0.9765]])\n",
      "tensor([[ 0.4716,  0.4065,  0.7690,  1.8287],\n",
      "        [ 0.8742,  0.9753,  0.1557,  0.0602],\n",
      "        [-1.0113,  0.6423,  0.4338,  1.5325],\n",
      "        ...,\n",
      "        [-0.7304,  0.6087,  0.6846,  1.6364],\n",
      "        [-0.1057,  0.8700,  0.3594,  0.7324],\n",
      "        [-0.9033,  2.1455,  0.4029,  1.1604]])\n",
      "tensor([[0.1710],\n",
      "        [0.5740],\n",
      "        [0.8555],\n",
      "        ...,\n",
      "        [0.5320],\n",
      "        [0.6810],\n",
      "        [0.9875]])\n",
      "tensor([[ 0.3448,  0.6420,  0.6593,  1.1213],\n",
      "        [-2.1827,  0.3602,  0.7042,  1.2803],\n",
      "        [-0.9700,  1.0939,  0.1077,  1.7086],\n",
      "        ...,\n",
      "        [ 0.4895,  0.8776,  0.6395,  1.5455],\n",
      "        [-1.2850,  1.2958,  0.1047,  0.2472],\n",
      "        [-2.6124,  0.5931,  0.7063,  0.0635]])\n",
      "tensor([[0.2680],\n",
      "        [0.6730],\n",
      "        [0.9940],\n",
      "        ...,\n",
      "        [0.1925],\n",
      "        [0.9990],\n",
      "        [0.8390]])\n",
      "tensor([[ 0.3448,  0.6420,  0.6593,  1.1213],\n",
      "        [-2.1827,  0.3602,  0.7042,  1.2803],\n",
      "        [-0.9700,  1.0939,  0.1077,  1.7086],\n",
      "        ...,\n",
      "        [ 0.4895,  0.8776,  0.6395,  1.5455],\n",
      "        [-1.2850,  1.2958,  0.1047,  0.2472],\n",
      "        [-2.6124,  0.5931,  0.7063,  0.0635]])\n",
      "tensor([[0.2680],\n",
      "        [0.6730],\n",
      "        [0.9940],\n",
      "        ...,\n",
      "        [0.1925],\n",
      "        [0.9990],\n",
      "        [0.8390]])\n",
      "tensor([[-2.7455,  0.4938,  0.2208,  0.6867],\n",
      "        [ 0.4126,  1.7031,  0.6554,  1.3090],\n",
      "        [-1.6596,  0.4897,  0.3812,  0.0976],\n",
      "        ...,\n",
      "        [ 1.8655,  0.9064,  0.6965,  1.2548],\n",
      "        [ 2.9318,  0.7008,  0.2329,  0.3484],\n",
      "        [ 2.6446,  0.6313,  0.5637,  1.0006]])\n",
      "tensor([[0.9925],\n",
      "        [0.1055],\n",
      "        [0.9030],\n",
      "        ...,\n",
      "        [0.0050],\n",
      "        [0.1375],\n",
      "        [0.0195]])\n",
      "tensor([[ 2.7574,  1.3463,  0.3169,  1.0711],\n",
      "        [ 2.6068,  1.4262,  0.2526,  1.0733],\n",
      "        [ 0.2186,  0.4270,  0.1091,  0.9134],\n",
      "        ...,\n",
      "        [-0.6603,  1.1238,  0.5937,  1.0083],\n",
      "        [ 1.1378,  0.3621,  0.5986,  1.9335],\n",
      "        [ 0.3604,  0.3021,  0.5881,  1.0940]])\n",
      "tensor([[0.0080],\n",
      "        [0.0265],\n",
      "        [0.8475],\n",
      "        ...,\n",
      "        [0.7355],\n",
      "        [0.2260],\n",
      "        [0.3620]])\n",
      "tensor([[ 1.7703,  0.5412,  0.3040,  0.5476],\n",
      "        [-1.0968,  0.6122,  0.5794,  0.7160],\n",
      "        [ 0.7944,  2.1391,  0.2862,  0.0507],\n",
      "        ...,\n",
      "        [-2.2702,  1.8492,  0.8818,  1.1966],\n",
      "        [ 1.9210,  0.4205,  0.6835,  0.1054],\n",
      "        [ 0.7131,  2.4640,  0.7410,  1.8090]])\n",
      "tensor([[0.2630],\n",
      "        [0.7235],\n",
      "        [0.1395],\n",
      "        ...,\n",
      "        [0.8700],\n",
      "        [0.0705],\n",
      "        [0.0045]])\n",
      "tensor([[-2.7455,  0.4938,  0.2208,  0.6867],\n",
      "        [ 0.4126,  1.7031,  0.6554,  1.3090],\n",
      "        [-1.6596,  0.4897,  0.3812,  0.0976],\n",
      "        ...,\n",
      "        [ 1.8655,  0.9064,  0.6965,  1.2548],\n",
      "        [ 2.9318,  0.7008,  0.2329,  0.3484],\n",
      "        [ 2.6446,  0.6313,  0.5637,  1.0006]])\n",
      "tensor([[0.9925],\n",
      "        [0.1055],\n",
      "        [0.9030],\n",
      "        ...,\n",
      "        [0.0050],\n",
      "        [0.1375],\n",
      "        [0.0195]])\n",
      "tensor([[-0.9363,  2.0375,  0.6647,  0.8784],\n",
      "        [ 2.1768,  2.0689,  0.2002,  0.0957],\n",
      "        [ 2.2178,  1.0256,  0.8118,  1.7531],\n",
      "        ...,\n",
      "        [ 1.8106,  0.5004,  0.5103,  1.1521],\n",
      "        [-0.1882,  0.4675,  0.4198,  0.2818],\n",
      "        [ 1.7692,  1.6076,  0.4338,  0.7863]])\n",
      "tensor([[0.9210],\n",
      "        [0.0285],\n",
      "        [0.0010],\n",
      "        ...,\n",
      "        [0.1255],\n",
      "        [0.6285],\n",
      "        [0.0060]])\n",
      "tensor([[-2.1197,  1.4374,  0.8743,  1.9186],\n",
      "        [-1.6401,  2.4682,  0.8981,  0.8245],\n",
      "        [-2.2833,  0.9635,  0.2890,  1.5146],\n",
      "        ...,\n",
      "        [-2.2025,  0.7549,  0.2196,  0.9994],\n",
      "        [-2.9090,  1.2080,  0.8279,  0.4686],\n",
      "        [-1.4978,  0.4554,  0.1142,  1.9158]])\n",
      "tensor([[0.8000],\n",
      "        [0.8315],\n",
      "        [0.9985],\n",
      "        ...,\n",
      "        [0.9965],\n",
      "        [0.9270],\n",
      "        [0.9765]])\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: print the first 11 (features, labels) batches.\n",
    "cnt = 0\n",
    "for cnt, (xb, yb) in enumerate(jax_training_dataloader, start=1):\n",
    "    print(xb)\n",
    "    print(yb)\n",
    "    if cnt > 10:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the trainable MLP and persist its config alongside the model outputs.\n",
    "jax_net = lanfactory.trainers.MLPJaxFactory(network_config=network_config, train=True)\n",
    "\n",
    "model_dir_ = \"../data/jax_models/\" + MODEL + \"/\"\n",
    "# Training only checks/creates this folder later; create it now so a fresh run works.\n",
    "os.makedirs(model_dir_, exist_ok=True)\n",
    "with open(model_dir_ + MODEL + \"_jax_cpn_network_config.pickle\", \"wb\") as config_file:\n",
    "    pickle.dump(network_config, config_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer wiring the network to the training and validation dataloaders.\n",
    "jax_trainer = lanfactory.trainers.ModelTrainerJaxMLP(\n",
    "    model=jax_net,\n",
    "    train_config=train_config,\n",
    "    train_dl=jax_training_dataloader,\n",
    "    valid_dl=jax_validation_dataloader,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found folder:  ..\n",
      "Moving on...\n",
      "Found folder:  ../data\n",
      "Moving on...\n",
      "Found folder:  ../data/jax_models\n",
      "Moving on...\n",
      "Found folder:  ../data/jax_models/ddm\n",
      "Moving on...\n",
      "Epoch: 0 of 10\n",
      "Training - Step: 0 of 225 - Loss: 0.60183203\n",
      "Epoch 0/10 time: 10.109797954559326s\n",
      "Validation - Step: 0 of 225 - Loss: 0.3189134\n",
      "Epoch 0/10 time: 8.907364130020142s\n",
      "Epoch: 0 / 10, test_loss: 0.32131388783454895\n",
      "Epoch: 1 of 10\n",
      "Training - Step: 0 of 225 - Loss: 0.33258647\n",
      "Epoch 1/10 time: 9.075368881225586s\n",
      "Validation - Step: 0 of 225 - Loss: 0.31904417\n",
      "Epoch 1/10 time: 9.757975101470947s\n",
      "Epoch: 1 / 10, test_loss: 0.3192121684551239\n",
      "Epoch: 2 of 10\n",
      "Training - Step: 0 of 225 - Loss: 0.30525112\n",
      "Epoch 2/10 time: 8.732094049453735s\n",
      "Validation - Step: 0 of 225 - Loss: 0.3184372\n",
      "Epoch 2/10 time: 8.733233213424683s\n",
      "Epoch: 2 / 10, test_loss: 0.32003405690193176\n",
      "Epoch: 3 of 10\n",
      "Training - Step: 0 of 225 - Loss: 0.31566334\n",
      "Epoch 3/10 time: 8.568460702896118s\n",
      "Validation - Step: 0 of 225 - Loss: 0.3370237\n",
      "Epoch 3/10 time: 8.894043684005737s\n",
      "Epoch: 3 / 10, test_loss: 0.3194853961467743\n",
      "Epoch: 4 of 10\n",
      "Training - Step: 0 of 225 - Loss: 0.33415258\n",
      "Epoch 4/10 time: 8.90596604347229s\n",
      "Validation - Step: 0 of 225 - Loss: 0.31501514\n",
      "Epoch 4/10 time: 8.612452030181885s\n",
      "Epoch: 4 / 10, test_loss: 0.32039299607276917\n",
      "Epoch: 5 of 10\n",
      "Training - Step: 0 of 225 - Loss: 0.32927763\n",
      "Epoch 5/10 time: 9.014395952224731s\n",
      "Validation - Step: 0 of 225 - Loss: 0.31545284\n",
      "Epoch 5/10 time: 8.384047985076904s\n",
      "Epoch: 5 / 10, test_loss: 0.31847500801086426\n",
      "Epoch: 6 of 10\n",
      "Training - Step: 0 of 225 - Loss: 0.32268035\n",
      "Epoch 6/10 time: 8.719941139221191s\n",
      "Validation - Step: 0 of 225 - Loss: 0.32181156\n",
      "Epoch 6/10 time: 8.926185131072998s\n",
      "Epoch: 6 / 10, test_loss: 0.3206397294998169\n",
      "Epoch: 7 of 10\n",
      "Training - Step: 0 of 225 - Loss: 0.32026663\n",
      "Epoch 7/10 time: 9.069537878036499s\n",
      "Validation - Step: 0 of 225 - Loss: 0.32707316\n",
      "Epoch 7/10 time: 8.787101030349731s\n",
      "Epoch: 7 / 10, test_loss: 0.31883081793785095\n",
      "Epoch: 8 of 10\n",
      "Training - Step: 0 of 225 - Loss: 0.31532332\n",
      "Epoch 8/10 time: 8.644514083862305s\n",
      "Validation - Step: 0 of 225 - Loss: 0.32468107\n",
      "Epoch 8/10 time: 8.748849153518677s\n",
      "Epoch: 8 / 10, test_loss: 0.31939005851745605\n",
      "Epoch: 9 of 10\n",
      "Training - Step: 0 of 225 - Loss: 0.32577628\n",
      "Epoch 9/10 time: 8.57486081123352s\n",
      "Validation - Step: 0 of 225 - Loss: 0.32779425\n",
      "Epoch 9/10 time: 8.792699813842773s\n",
      "Epoch: 9 / 10, test_loss: 0.3195500373840332\n",
      "Saving training history to: ../data/jax_models/ddm//test_cpn_cpn_ddm__jax_training_history.csv\n",
      "Saving model parameters to: ../data/jax_models/ddm//test_cpn_cpn_ddm__train_state.jax\n",
      "Saving training config to: ../data/jax_models/ddm//test_cpn_cpn_ddm__train_config.pickle\n",
      "Saving training data details to: ../data/jax_models/ddm//test_cpn_cpn_ddm__data_details.pickle\n"
     ]
    }
   ],
   "source": [
    "# Train, evaluate each epoch, then save history, train state and configs.\n",
    "model_dir_ = \"../data/jax_models/\" + MODEL + \"/\"\n",
    "train_state = jax_trainer.train_and_evaluate(\n",
    "    output_folder=model_dir_,\n",
    "    output_file_id=MODEL,\n",
    "    run_id=\"test_cpn\",\n",
    "    save_all=True,\n",
    "    save_data_details=True,\n",
    "    verbose=1,\n",
    "    wandb_on=False,\n",
    "    wandb_project_id=\"test_cpn\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network config shared by the training and inference networks\n",
    "network_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference-mode (train=False) network with the same architecture; the\n",
    "# saved train state is loaded into it below.\n",
    "jax_infer = lanfactory.trainers.MLPJaxFactory(train=False, network_config=network_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "passing through transform\n"
     ]
    }
   ],
   "source": [
    "# Saved state files appear to be named {run_id}_cpn_{output_file_id}__train_state.jax\n",
    "# (see the training log), so build the path from MODEL instead of hard-coding it.\n",
    "train_state_path = (\n",
    "    \"../data/jax_models/\" + MODEL + \"/test_cpn_cpn_\" + MODEL + \"__train_state.jax\"\n",
    ")\n",
    "my_state = jax_infer.load_state_from_file(file_path=train_state_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "passing through transform\n"
     ]
    }
   ],
   "source": [
    "# The CPN input is the parameter vector only (input_mat below has one column\n",
    "# per model parameter), so input_dim is n_params.\n",
    "# NOTE(review): n_params + 2 would match a (params, rt, choice) LAN input instead.\n",
    "train_state_path = (\n",
    "    \"../data/jax_models/\" + MODEL + \"/test_cpn_cpn_\" + MODEL + \"__train_state.jax\"\n",
    ")\n",
    "forward_pass, forward_pass_jitted = jax_infer.make_forward_partial(\n",
    "    seed=42,\n",
    "    input_dim=model_config[\"n_params\"],\n",
    "    state=train_state_path,\n",
    "    add_jitted=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "# Compare network output with simulator choice proportions along a grid of\n",
    "# drift rates v, holding a, z, t fixed.\n",
    "a, z, t = 1.5, 0.5, 0.3\n",
    "n_grid = 50\n",
    "n_sim_samples = 2000\n",
    "v = np.linspace(-3, 3, n_grid)\n",
    "\n",
    "# Simulator reference: empirical proportion of choices == -1 at each v.\n",
    "choice_p_list = []\n",
    "for v_tmp in v:\n",
    "    sim_out = ssms.basic_simulators.simulator.simulator(\n",
    "        model=MODEL, theta=[v_tmp, a, z, t], n_samples=n_sim_samples\n",
    "    )\n",
    "    choice_p_list.append(\n",
    "        np.sum(sim_out[\"choices\"] == -1.0) / sim_out[\"choices\"].shape[0]\n",
    "    )\n",
    "\n",
    "# Network input: one row per grid point, columns ordered (v, a, z, t)\n",
    "# as in model_config[\"params\"].\n",
    "input_mat = jnp.column_stack(\n",
    "    [jnp.asarray(v), jnp.full(n_grid, a), jnp.full(n_grid, z), jnp.full(n_grid, t)]\n",
    ")\n",
    "\n",
    "net_out = forward_pass_jitted(input_mat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[5.000e-04, 9.995e-01]])"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Choice probabilities reported by the last simulator run in the loop above\n",
    "# (the final grid point, v = 3.0).\n",
    "sim_out[\"choice_p\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 0, 'v')"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGwCAYAAAB7MGXBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA+SklEQVR4nO3deXhU9aH/8fc5M5lJQhYSyE4gyCJL2GQz4gakICourda2VihWb/WiVfn1VvFWudZb6aKtt0pLpeLSloIbuCGCEbAoioC4sclmEEgIRJIQIJOZc35/BKOpgAkk+c5MPq/nOU8y35yT+cyYh/l4vmexXNd1ERERETHENh1ARERE2jaVERERETFKZURERESMUhkRERERo1RGRERExCiVERERETFKZURERESM8poO0BiO47B7924SExOxLMt0HBEREWkE13WpqqoiOzsb2z7+/o+IKCO7d+8mNzfXdAwRERE5CTt37qRTp07H/XlElJHExESg7sUkJSUZTiMiIiKNUVlZSW5ubv3n+PFERBn5YmomKSlJZURERCTCfNMhFjqAVURERIxSGRERERGjVEZERETEKJURERERMUplRERERIxSGRERERGjVEZERETEKJURERERMUplRERERIxSGRERERGjmlxG3njjDcaPH092djaWZbFgwYJv3GbZsmWcccYZ+P1+unfvzuOPP34SUUVERCQaNbmMVFdXM2DAAGbMmNGo9bdv385FF13EyJEjWbduHbfeeivXXXcdr776apPDioiISPRp8o3yxo0bx7hx4xq9/syZM+natSsPPPAAAL1792bFihX84Q9/YOzYsU19+mZVWnmEoOMSY1t4bAuvx8ZrW3g9FjG2jW2f+MY+4SoYcjhYE+RwbYgYj01cjIfYGA+eCH09IiIS3Vr8rr0rV66ksLCwwdjYsWO59dZbj7tNTU0NNTU19Y8rKytbJFvpw2M4rWYTAC51H9QOUHN0gS8/vN1jfO9i4X51zLLqtrGsujW+8rhuzMLFwgFct+6r49Zt67h139eN2biWBZZdd6dDy65fbPuL7y2CjkWt+8VXqHUsap26rw42DhYhPASx6x5bHrC9YHnBY2PZXiw7pm7MG4Nl+8Drw/bEYHt9WF4ftteHJ8ZH0PJx2I3hkOPlkBPDQSeG6pCXg6EYqkMeKoNegp5YkhMTSEuMpWOCj7REPx0T6pa0RD8dEnz4vZ4W+W8pIiKRq8XLSElJCRkZGQ3GMjIyqKys5PDhw8TFxX1tm+nTp3PPPfe0dDRi3RoSrCPN/4vdf/vaVF/0HvckfocFfNPnvQsEmxqqcWr3eKgmtm5xY6kmjko3lj3EUU0sNXY7DnmTORKTTI0vhZA/hVBcKsSlYsV3ILZdAkmxMXRPT6BXViLpibEtE1RERMJGi5eRkzF16lSmTJlS/7iyspLc3Nxmf56eN8+HYA3g4jgQdBxCR5eg4xIKffkVAPerLaPue8v9YsShNhgiUOtQEwpRUxsiEAxREwwROPp9IBjCtsDnsYj12vi84Ldt/F7weS18toU/xsZyHQJBh0BtkNpgiNpgsP531IaCBIIhcELEeW3iYizivBZxXoj1QqwHYmMs/B4LLw6OEyJYW0swGCAYrKW2NkgoWEswWEsoGCQUDBAKBnGCNbihAE6wFkIBCNXihuq+t5xaYtxafG4NPjdAjBsgxq3B6wTwOjV4nBpsNwRAjBWiPdW0p/qrO5YaCh1djtEDj7gxlJNIqZvKu24qFd403OQc4jt0JjU7j+zOPcjLO42YGF/z/BGIiIhxLV5GMjMzKS0tbTBWWlpKUlLSMfeKAPj9fvx+f0tHg6Ts+m9tIBo/3r54XS3+2kJBqK2GmoMQOHj0a1X9Y7emiiPVFRyp+hynuhz3cDnW4XK8Rz4npuYA/trP8bhBYq1asikn2ypnENR1vgNHl61Hn8q1KLNTKfHmsNPTiU+tXHbYOWxzO1Hitifo
UFciHZdOKXGcf3o6o3ql0z8nOWKPAxIRiWYtXkYKCgpYuHBhg7ElS5ZQUFDQ0k8trcnjBU8yxCYf88cWEHd0OSbXrSsxh8qheh9U7SZQXsznJZ9yZF8xVO4i7nAJKaH9xFgh0tz9pNXup1/tBw1+TaUbx1Y3hy1ONlvcHD7encejn3Xnj0VxdEzwcV7PumJydo+OJMfFNOc7ICIiJ6nJZeTgwYNs2bKl/vH27dtZt24dqampdO7cmalTp7Jr1y6efPJJAG644QYefvhhfv7zn3Pttdfy+uuv89RTT/Hyyy8336uQyGdZ4E+sW1K6AIPxARn/tprrhNi1q5g9xZ/gP7CNxINbSajaRrvKrcRWFZPEYQZZWxhkf/k36mCz2c1l1ZGerF7Xk/vW9qTUTmNwl1RG9UqnsE8G3dISWvPViojIV1iu6zbpEMlly5YxcuTIr41PnDiRxx9/nB/96Efs2LGDZcuWNdjmtttuY/369XTq1Im77rqLH/3oR41+zsrKSpKTk6moqCApKakpcaUtCdbA/q2wbxOUbYayjbBrDRz49GurlrgprHZ6ssbpyTtOb7zZ/bn8jE6MH5BNx4RWmCIUEWkDGvv53eQyYoLKiJySqhIofht2roKdb8Oe98FpeDpRsZPGImcYS9xhJHcv4PLBnRndO53YGJ2KLCJyslRGRI4ncAh2r4Wd70Dx27jb/4UVPFz/4z1uKq+GhrDcW0Bm/kguH9yFIV1SdPCriEgTqYyINFagGrYUwYYXcDa+gl17sP5H+9wkFocGsyr+XDoPHsflgzvTtWM7g2FFRCKHyojIyQjWwLZluOufJ7j+ZWICB+p/tNXJYlboIrZlX8z4wacxvn8W7eOj8YRwEZHmoTIicqpCtbBjBcGPnsf98BliglUAlLlJPBkcw1zGcEavbnz7jE6MPD0dn7fJ950UEYlqKiMizammCtY+SeitP+Gp+gyAw66Pp0Ln8WjoQirjOjG+fzYTz8qje7pOExYRAZURkZYRqoX1z8Ob/wcldRdcc7BYFBrKrOBFrKMH4/tnc/Oo7vTISDQcVkTELJURkZbkurDjX/DmH2HLkvrhRaGhTKudyF4rlQvzs7h5dHd6ZepvVkTaJpURkdZSuh5WzoD3/wluiMNWPPcFvss/QoU42FzQN5ObR3enb/axL5UvIhKtVEZEWlvJR/DiLbBrNQDbY3vzn5UT2eB0BqCwdwY/Hd2d/p3aGwwpItJ6VEZETHBCsHo2vHYPBKpwbS9F7a/kp3vGcMitu8z8xf2z+J9L+uqy8yIS9Rr7+a1zEUWak+2BYdfDTaug18VYTpDC8n/yftr/cEeP3dgWvPTBHsb84Q1eeH83EfD/AiIiLU57RkRa0saX4eWfQdVuAA50u4yflF3BO3vr/j/gW30y+NVl+aQnxZpMKSLSIrRnRCQc9Lqobi/J8BsAi/ZbFzDXmsq9BR5iPBZL1pdS+PvlPL16p/aSiEibpTIi0tL8iTDuN3B9EaSehlWxk2vW/wdFl0O/nGQqjwT5r2c+4EePvcvuA4e/+feJiEQZlRGR1pIzGK4rgs4FUFNB54XXsGDEDn5+wen4vDbLN5cx5g9vMOedYu0lEZE2RWVEpDXFp8I1CyD/CnCCeF6YzH86c1l48wgGdW7PwZogd87/kP/8x1oVEhFpM1RGRFpbTCx8exac87O6x2/8ju4rpvDMdYO56+I++Dw2r3xUwksf7DGbU0SklaiMiJhg2zD6LrjkYbC98OHTeP5+OT8+I5nJI7sDcN/CDRwKBA0HFRFpeSojIiadcQ1c/Qz4k6D4LXj0W9zQ36JTShx7Ko4wY+kW0wlFRFqcyoiIad1GwrWvQnIu7N+C//Ex3F9QA8CsN7bz6f5qwwFFRFqWyohIOMjoA9e9BlkD4dB+hq+4jh90qSQQcrj3pfWm04mItCiVEZFwkZgJkxZC3jlYgYPcU/1LsuwDvLZhL0s37TWdTkSkxaiMiIQTXzu46m/QoQcx
B3fzbPuHiKWGX764nkDQMZ1ORKRFqIyIhJu4FLj6KYhLJfvQBv4UN5Md+6qY/eZ208lERFqEyohIOEo9Db43Bzw+RrnvcLt3Hg8VfUJp5RHTyUREmp3KiEi46lIAl84A4Abvi1wceo3pCzcYDiUi0vxURkTCWf/vwvlTAfhf72z2frCY1TvKDYcSEWleKiMi4e6826HflcRYIWbGPMis5xYRcnTfGhGJHiojIuHOsuCSh6nNGUaSdYj/PnA3z/3rPdOpRESajcqISCSIiSXmB3OpjOtEZ7uMnkt/wucVlaZTiYg0C5URkUjRrgPxP3qWKtoxgM3sfGwSuJquEZHIpzIiEkG8Gb0oLvwLta6H/gde45PFj5iOJCJyylRGRCJM37PHsyTjxwCkrfwle3YXG04kInJqVEZEItCoa+9lm6cr7TnI5sdv4khtyHQkEZGTpjIiEoFiY2Npd+WfCGFxXmA5f39yFq6OHxGRCKUyIhKhMnqdxZ7e1wIwrvi3/HPFesOJREROjsqISATrdPm9VMbmkGPtp3bxPbyrq7OKSARSGRGJZL52JF7xMADX2Iv589/+qZvpiUjEURkRiXBW91EE+30P23K5o3YGN//tbQJBx3QsEZFGUxkRiQLecdMJxXWgp72LM3f/jXte/Nh0JBGRRlMZEYkG8al4LvwtADd55/POqrd46t2dhkOJiDSOyohItMj/DvQYi88K8euYv3LXgg94f+cB06lERL6RyohItLAsuOgBXF8CQ+zNXMkSbvj7GsqrA6aTiYickMqISDRpn4s1ehoAU2PmQsUuZq/YbjiUiMiJqYyIRJuhP4ZOw2jHYe6Nmc2cdz7V5eJFJKypjIhEG9sDl/wR146h0PMeZxx5mxfW7TadSkTkuFRGRKJRem+sgskA3Op9ltkrtuneNSIStlRGRKLVWT/FjWlHvr2D7LI3WLltv+lEIiLHpDIiEq3adcAadh0AP/U+x2M6kFVEwpTKiEg0K7gZxxvLQHsbtZsXU7z/kOlEIiJfozIiEs0S0rCH1u0ducXzHI+/qb0jIhJ+VEZEot1ZPyXk8TPI3sJnaxZysCZoOpGISAMqIyLRLjEDe8gkAK53n+KZd4sNBxIRaUhlRKQNsEbcSsj2MdTezAcrXsJxdJqviIQPlRGRtiApC2fQBAC+e2gOyzbvNRxIRORLKiMibUTMuVMIWjGcaW/graIXTMcREamnMiLSViTncCT/+wCMLHmMzaVVhgOJiNQ5qTIyY8YM8vLyiI2NZfjw4axateqE6z/44IOcfvrpxMXFkZuby2233caRI0dOKrCInLyE0T8niJcRno95ffHzpuOIiAAnUUbmzZvHlClTmDZtGmvXrmXAgAGMHTuWvXuPPQc9Z84c7rjjDqZNm8aGDRt49NFHmTdvHnfeeecphxeRJmqfS3mPKwDot2Umn1cHDAcSETmJMvL73/+e66+/nkmTJtGnTx9mzpxJfHw8s2fPPub6b731FiNGjOAHP/gBeXl5jBkzhu9///sn3JtSU1NDZWVlg0VEmkfauDsI4mGE9SGvv/aS6TgiIk0rI4FAgDVr1lBYWPjlL7BtCgsLWbly5TG3Oeuss1izZk19+di2bRsLFy7kwgsvPO7zTJ8+neTk5PolNze3KTFF5ASs1K7szL0EgKz3H6I25BhOJCJtXZPKyL59+wiFQmRkZDQYz8jIoKSk5Jjb/OAHP+CXv/wlZ599NjExMXTr1o3zzz//hNM0U6dOpaKion7ZuXNnU2KKyDfIHv8Lgtic5azl7X8tMR1HRNq4Fj+bZtmyZdx333386U9/Yu3atTz33HO8/PLL3Hvvvcfdxu/3k5SU1GARkebjT+/OprRxAMS9db/hNCLS1jWpjHTs2BGPx0NpaWmD8dLSUjIzM4+5zV133cU111zDddddR79+/bj88su57777mD59Oo6j3cMipmRe/N+EXIshgVVsfu9fpuOISBvWpDLi8/kYPHgwRUVF9WOO
41BUVERBQcExtzl06BC23fBpPB4PAK6rS1KLmNKhS1/eS647/qt62YNmw4hIm9bkaZopU6Ywa9YsnnjiCTZs2MCNN95IdXU1kybV3YhrwoQJTJ06tX798ePH8+c//5m5c+eyfft2lixZwl133cX48ePrS4mImBF79n8C0OfAco5U7jOcRkTaKm9TN7jqqqsoKyvj7rvvpqSkhIEDB7Jo0aL6g1qLi4sb7An5xS9+gWVZ/OIXv2DXrl2kpaUxfvx4fvWrXzXfqxCRk9Jn8Pl88koePdjBx689St9v3246koi0QZYbAXMllZWVJCcnU1FRoYNZRZrZ4sf/lzE7fseumDxy7lwHlmU6kohEicZ+fuveNCJtXLdRkzjs+sip3UHlJ2+ZjiMibZDKiEgb161zDm/GngPA3uV/MZxGRNoilRER4XD+DwHI3bUIjlQYTiMibY3KiIgw/NxxfOLk4KeG/W//w3QcEWljVEZEhPTkOFalXgxA6N3HzYYRkTZHZUREAEg+cwI1rpf06k24u94zHUdE2hCVEREBYOSgXixxhwOw/41HDKcRkbZEZUREAGjn97KjyxUAJHyyAGoOmg0kIm2GyoiI1MsfcRHbnQxinUMEP3zWdBwRaSNURkSk3tk90njJe/TmeSsfNZxGRNoKlRERqef12NTkf59a10Py/veh9GPTkUSkDVAZEZEGxg7rzxJnMACBVY8ZTiMibYHKiIg0kJ+TxPKEcXUPPpgHtYfNBhKRqKcyIiINWJZF7pCL+MztiK+2Eta/YDqSiEQ5lRER+ZpLB+UyL3g+AIFVs82GEZGopzIiIl+TmxrP5uxLCbkWvl1vQ9lm05FEJIqpjIjIMZ03ZCCvO4PqHqx9wmwYEYlqKiMickwX9cviGXc0AMH35kCwxnAiEYlWKiMickzJ8THYPb9FiZuC90g5bHzZdCQRiVIqIyJyXJee0ZmnQucB4K79m+E0IhKtVEZE5LhG9krnVe+ougfbl0P1frOBRCQqqYyIyHH5vR769z+DD508LDcEG180HUlEopDKiIic0CUDsnk5dCYA7kfzDacRkWikMiIiJzQkL4XlMSPqHux4A6r3mQ0kIlFHZURETijGY9Pj9H584HTFch3YoMvDi0jzUhkRkW80und6/VQNH2uqRkSal8qIiHyj83qm8YozHAB3xwo4WGY4kYhEE5UREflG7eN9ZHbpxTrnNE3ViEizUxkRkUYZpakaEWkhKiMi0iije6V/OVXz6ZtQVWo4kYhEC5UREWmU7ukJ2CldWOd001SNiDQrlRERaRTLshjVK52X6qdqFhjNIyLRQ2VERBptdO90XgkNAzRVIyLNR2VERBptWNdUDvgyWet0x8LVVI2INAuVERFpNL/Xwzk90nRWjYg0K5UREWmSUb3TWRiqO6uGT9+Cyj1mA4lIxFMZEZEmGXl6OiVWB9Y4PUBTNSLSDFRGRKRJ0hL9DOjUXlM1ItJsVEZEpMlG90pn4dGzaih+Gyp3mw0kIhFNZUREmmxU73RK6MAa93TAhfWaqhGRk6cyIiJN1icriazkWF4MHj2QVVM1InIKVEZEpMm+uBrrFxdAY+fbULHLbCgRiVgqIyJyUkb3TqeUVN63etcNrH/ebCARiVgqIyJyUs7q1pHYGJvnAkf3jmiqRkROksqIiJyU2BgPI7p15JXQMFws+GwVVHxmOpaIRCCVERE5aaN6p7OXFNbH9K0b0FSNiJwElREROWmjeqUDMO/QkLqBj54zmEZEIpXKiIictKzkOPpmJ/FKaGjdVM2u1bpXjYg0mcqIiJyS0b3SKSOF7bFHz6rZvMhsIBGJOCojInJKRvXOAOCFwwPqBja9YjCNiEQilREROSX9c5LpmODj5cCguoFtyyBQbTSTiEQWlREROSW2bTHy9HQ+cXMo9+VAqAa2LjUdS0QiiMqIiJyy0b0zAIvXnDPqBjRVIyJNoDIiIqfsnB4d8Xlt5h/qXzeweRE4IbOhRCRiqIyI
yClr5/dybo803nVO54gnEQ7tg8/eNR1LRCKEyoiINIsL8jMJ4uUt+4upmoVmA4lIxDipMjJjxgzy8vKIjY1l+PDhrFq16oTrHzhwgMmTJ5OVlYXf76dnz54sXKh/qESiybd6Z+C1LZ6rPjpVo+NGRKSRmlxG5s2bx5QpU5g2bRpr165lwIABjB07lr179x5z/UAgwLe+9S127NjBM888w6ZNm5g1axY5OTmnHF5EwkdyfAwF3Tqw3BlAyPLCvs2wb4vpWCISAZpcRn7/+99z/fXXM2nSJPr06cPMmTOJj49n9uzZx1x/9uzZlJeXs2DBAkaMGEFeXh7nnXceAwYMOO5z1NTUUFlZ2WARkfA3Lj+LKuL5wNuvbmCz9o6IyDdrUhkJBAKsWbOGwsLCL3+BbVNYWMjKlSuPuc0LL7xAQUEBkydPJiMjg/z8fO677z5CoeMfaT99+nSSk5Prl9zc3KbEFBFDxvTNwLb48qwaTdWISCM0qYzs27ePUChERkZGg/GMjAxKSkqOuc22bdt45plnCIVCLFy4kLvuuosHHniA//3f/z3u80ydOpWKior6ZefOnU2JKSKGdEzwMzQvlaLQ0YNYi1fCoXKzoUQk7LX42TSO45Cens4jjzzC4MGDueqqq/jv//5vZs6cedxt/H4/SUlJDRYRiQzj8jPZRRrbvaeB68Ani01HEpEw16Qy0rFjRzweD6WlpQ3GS0tLyczMPOY2WVlZ9OzZE4/HUz/Wu3dvSkpKCAQCJxFZRMLZBflZALx45Isb5+nMORE5sSaVEZ/Px+DBgykqKqofcxyHoqIiCgoKjrnNiBEj2LJlC47j1I9t3ryZrKwsfD7fScYWkXCVmRzLoM7tWRIaXDewpQiCNWZDiUhYa/I0zZQpU5g1axZPPPEEGzZs4MYbb6S6uppJkyYBMGHCBKZOnVq//o033kh5eTm33HILmzdv5uWXX+a+++5j8uTJzfcqRCSsjMvP5CM3j3K7AwQOwo5/mY4kImHM29QNrrrqKsrKyrj77rspKSlh4MCBLFq0qP6g1uLiYmz7y46Tm5vLq6++ym233Ub//v3Jycnhlltu4fbbb2++VyEiYWVcfhb3LdzIotqB/MBTVHdWTffCb95QRNoky3Vd13SIb1JZWUlycjIVFRU6mFUkQlz0x3+RVrKcx32/g6QcuO1jsCzTsUSkFTX281v3phGRFjEuP5OVTl+OWLFQuQtKPjAdSUTClMqIiLSIC/KzqMHH8tDRC6Bt1Fk1InJsKiMi0iK6pyfQPT2BxUHdxVdETkxlRERazLj8TJY6A3Gw66ZpKj4zHUlEwpDKiIi0mAvyMyknibVuj7oB3atGRI5BZUREWkyfrCQ6p8Z/ZapGZUREvk5lRERajGVZjMvP5DXn6NVYt78BRyrNhhKRsKMyIiIt6oL8TLa52Wx3s8Cpha2vm44kImFGZUREWtSATu3JSo5lcUhTNSJybCojItKibNtibN9MXvvixnmfvAqhoNlQIhJWVEZEpMWNy89krduDz0mEw5/DzndMRxKRMKIyIiItbkheKikJcbweGlA3sHmR2UAiElZURkSkxXlsizF9M3n9i+NGNr9qNpCIhBWVERFpFePyM/mX048gHti3Ccq3mY4kImFCZUREWsWZp3XAimvPqtDpdQObF5sNJCJhQ2VERFpFjMdmdO90ipwvpmp03IiI1FEZEZFWM6ZPJkXOIADcHSt0NVYRAVRGRKQVnduzI3s8OWx1srCcWti21HQkEQkDKiMi0mrifV7O6ZHG60f3juisGhEBlRERaWVj+mY0LCOOYzaQiBinMiIirWp0r3TWuKdT6cbDoX2we63pSCJimMqIiLSqDgl+Buals9zpXzegs2pE2jyVERFpdWP6ZFBUfxdflRGRtk5lRERa3Zg+mSx3+hNyLSj9ECo+Mx1JRAxSGRGRVte5QzwZmTmsdXvUDeisGpE2TWVERIwYqxvnichRKiMiYsSYvhlfXo11
+3IIHDKcSERMURkRESP6ZCVRndSDnU4aVvAIbF9uOpKIGKIyIiJGWJbFmPwv71WjU3xF2i6VERExZkyfzPqrsbqbXwXXNZxIRExQGRERY4bmpbDR359q149VtQdKPjAdSUQMUBkREWO8HptzeueywulXN6CzakTaJJURETFqTN8MXnPqTvF1ddyISJukMiIiRp3bI42Vdl0ZsXatgapSw4lEpLWpjIiIUXE+D7179GCdc1rdwCeLzQYSkVanMiIixo3pk/GVq7FqqkakrVEZERHjCntnsNStO8XX2fo6BGsMJxKR1qQyIiLGpbTz0a7LGZS4Kdi1h2DHCtORRKQVqYyISFgY0zeL10O6GqtIW6QyIiJh4Vt9vrxxXmjTIl2NVaQNURkRkbCQmxpPeXoBR9wYPBXFULbRdCQRaSUqIyISNs7L78JbTt+6B5sWmg0jIq1GZUREwsaYPpkscQYD4Gx42XAaEWktKiMiEjZ6ZyWyPvEsAOzda6CqxHAiEWkNKiMiEjYsy2Jofh/ec7rXDWx6xWwgEWkVKiMiElbG9ctiSahuqiakqRqRNkFlRETCyqDc9rwXXzdVY21fBjVVZgOJSItTGRGRsGLbFr36DWG7k4Ht1MKWItORRKSFqYyISNi5sH82S5whgKZqRNoClRERCTuDO6ew2n8mAM7mVyFUaziRiLQklRERCTu2bZHd7zz2u4nEBCqgeKXpSCLSglRGRCQsjevfiaLQGYCmakSincqIiISlIXmpvOOrm6oJfPySbpwnEsVURkQkLHlsi+T8MRx2fcRVfwalH5mOJCItRGVERMJW4YA8Vjj9AE3ViEQzlRERCVvD8lJ5yzsMgEMfvmA4jYi0lJMqIzNmzCAvL4/Y2FiGDx/OqlWrGrXd3LlzsSyLyy677GSeVkTaGK/Hxtv7QhzXIrH8Y6j4zHQkEWkBTS4j8+bNY8qUKUybNo21a9cyYMAAxo4dy969e0+43Y4dO/jZz37GOeecc9JhRaTtOW9QH1a7PQFN1YhEqyaXkd///vdcf/31TJo0iT59+jBz5kzi4+OZPXv2cbcJhUJcffXV3HPPPZx22mmnFFhE2pYzT0tlhaduqqbqfU3ViESjJpWRQCDAmjVrKCws/PIX2DaFhYWsXHn8ixL98pe/JD09nR//+MeNep6amhoqKysbLCLSNnk9NqEe4wBI3LMSDh8wG0hEml2Tysi+ffsIhUJkZGQ0GM/IyKCkpOSY26xYsYJHH32UWbNmNfp5pk+fTnJycv2Sm5vblJgiEmWGDxnGJ04OHkI4nywxHUdEmlmLnk1TVVXFNddcw6xZs+jYsWOjt5s6dSoVFRX1y86dO1swpYiEu4JuHXjDHgpA+ZoFZsOISLPzNmXljh074vF4KC0tbTBeWlpKZmbm19bfunUrO3bsYPz48fVjjuPUPbHXy6ZNm+jWrdvXtvP7/fj9/qZEE5EoFuOxOdR1LGxfQMLOpRAMgNdnOpaINJMm7Rnx+XwMHjyYoqKi+jHHcSgqKqKgoOBr6/fq1YsPP/yQdevW1S+XXHIJI0eOZN26dZp+EZFGyx8+ir1ue2Kdapzt/zIdR0SaUZP2jABMmTKFiRMnMmTIEIYNG8aDDz5IdXU1kyZNAmDChAnk5OQwffp0YmNjyc/Pb7B9+/btAb42LiJyIiO6p/OCNZgrKKJs9XNk9BhtOpKINJMml5GrrrqKsrIy7r77bkpKShg4cCCLFi2qP6i1uLgY29aFXUWkefm8Np/njoGdRcRte7XuxnmWZTqWiDQDy3XD/1aYlZWVJCcnU1FRQVJSkuk4ImLI0o+KGfb0ENpZNTjXLcXudIbpSCJyAo39/NYuDBGJGGf1yuFNBgJQ+u6zZsOISLNRGRGRiOH3eijLqTtWxN78iuE0ItJcVEZEJKJkD72MoGuTcXgrzv7tpuOISDNQGRGRiFKQ35219AJg9zvPGE4jIs1BZUREIkpsjIfijLqpGnf984bTiEhzUBkR
kYiTOuQKHNci9+CHuAd0uwiRSKcyIiIRZ/iAfFYfnarZ9848w2lE5FSpjIhIxGnn97K5QyEAzkc6xVck0qmMiEhESjjj24Rci4yq9VC+zXQcETkFKiMiEpHOHtiXlW4fACpXP204jYicCpUREYlIHRP8fNi+bqom+KGmakQimcqIiESshIGXU+t6SK3aBPs+MR1HRE6SyoiIRKxzB/RkhZMPwJH3NFUjEqlURkQkYnXp0I41iSMBCHygq7GKRCqVERGJaLH5l1Djekmq2gql603HEZGToDIiIhHt/AHdecMZAEBQe0dEIpLKiIhEtL7ZSazwnwNA4P1nwXUNJxKRplIZEZGIZlkW/r4XccSNIf7gDij5wHQkEWkilRERiXjn9zuN151BADgfPmc4jYg0lcqIiES8oV1Ted1zNgC1HzyjqRqRCKMyIiIRL8Zj4zl9LNWuH//Bz2DXWtORRKQJVEZEJCqM7NeFIucMAFzdyVckoqiMiEhUOLdnGos4C4Dgh8+B4xhOJCKNpTIiIlEh3ufFOa2QSjeOmOo98Nkq05FEpJFURkQkaozql8sSZ3Ddg490Vo1IpFAZEZGoMbpXOgudAgBCH80HJ2Q4kYg0hsqIiESNDgl+DuWeywG3HZ5De+HTN01HEpFGUBkRkagyum8nXg0NrXugqRqRiKAyIiJRZUyfTF48OlXjrH8BQkHDiUTkm6iMiEhU6dwhnvK04exzk7AP74fty01HEpFvoDIiIlGnMD+HRZqqEYkYKiMiEnXG9MnghVDdBdDc9Qug5qDZQCJyQiojIhJ1+mYnsStpENucTKzAQfhYe0dEwpnKiIhEHcuy+FbfTOaFRtYNrHnCbCAROSGVERGJShf2y+LZ0LnUuh7YtRpKPjIdSUSOQ2VERKLS0LwUOmR0YvEXl4dfq70jIuFKZUREopJlWfxoRB5zQ6MAcD+YB7WHDacSkWNRGRGRqHXZwBw+8g9ip5OGdaQC1j9vOpKIHIPKiIhErTifh6uG5TEvdH7dgA5kFQlLKiMiEtWuKejCc+75hFwLit+Css2mI4nIv1EZEZGoltM+joF9e/O6M6huQAeyioQdlRERiXo/Oqsr/zx6IKuzbg4EawwnEpGvUhkRkag3NC+Fsoxz2OOmYh8uh40vmY4kIl+hMiIiUc+yLCaM6MZTofMAcHQgq0hYURkRkTZh/IBsFvvG4LgW9vblUL7NdCQROUplRETahNgYD6POHMy/nH51A2v/ZjaQiNRTGRGRNuOHZ3bhKafuQNbaNX+DUK3hRCICKiMi0oZkJMUS0/diytxkYg6XweZXTUcSEVRGRKSNmXB2d54JnQtAYNVsw2lEBFRGRKSNGZTbnnVplwAQs/11OLDTcCIRURkRkTbFsizGnTuCt0J9sHAJ6UBWEeNURkSkzbmwXxYv+8YAUPPuk+CEDCcSadtURkSkzfF5bTKHX8HnbgLxh/fAliLTkUTaNJUREWmTvlfQk/lO3YGsFW/+1XAakbZNZURE2qS0RD97e1wFQOKnS2D/VsOJRNoulRERabMuGjWS10MDsXGoKvqd6TgibdZJlZEZM2aQl5dHbGwsw4cPZ9WqVcddd9asWZxzzjmkpKSQkpJCYWHhCdcXEWkt/Tol82bOJADi1z+l03xFDGlyGZk3bx5Tpkxh2rRprF27lgEDBjB27Fj27t17zPWXLVvG97//fZYuXcrKlSvJzc1lzJgx7Nq165TDi4icqu9c8m1WOH3xEGL/q781HUekTbJc13WbssHw4cMZOnQoDz/8MACO45Cbm8vNN9/MHXfc8Y3bh0IhUlJSePjhh5kwYUKjnrOyspLk5GQqKipISkpqSlwRkW/08OzHuKn4VgLE4JvyISRlmY4kEhUa+/ndpD0jgUCANWvWUFhY+OUvsG0KCwtZuXJlo37HoUOHqK2tJTU19bjr1NTUUFlZ2WAREWkpl176XVY7p+Ojll0LdeyISGtrUhnZt28foVCIjIyMBuMZ
GRmUlJQ06nfcfvvtZGdnNyg0/2769OkkJyfXL7m5uU2JKSLSJLkd2rGx508A6LDx7zhVZYYTibQtrXo2za9//Wvmzp3L/PnziY2NPe56U6dOpaKion7ZuVMHlYlIyxp36dV85HYjlhq2vqhjR0RaU5PKSMeOHfF4PJSWljYYLy0tJTMz84Tb3n///fz6179m8eLF9O/f/4Tr+v1+kpKSGiwiIi2pQ2Is2/vcCED25r9Re7DccCKRtqNJZcTn8zF48GCKir68dLLjOBQVFVFQUHDc7X77299y7733smjRIoYMGXLyaUVEWtCoSyfyCZ1px2E+XqBjR0RaS5OnaaZMmcKsWbN44okn2LBhAzfeeCPV1dVMmlR3rv6ECROYOnVq/fq/+c1vuOuuu5g9ezZ5eXmUlJRQUlLCwYMHm+9ViIg0g3axPvb0nwxA1y1PUl35ueFEIm1Dk8vIVVddxf3338/dd9/NwIEDWbduHYsWLao/qLW4uJg9e/bUr//nP/+ZQCDAFVdcQVZWVv1y//33N9+rEBFpJmde/GM+tXJI5iDrntO/UyKtocnXGTFB1xkRkda09sU/ccaaqex3k+DWD+iQkmI6kkhEapHrjIiItAUDL7iOEjuDDlYl7z77oOk4IlFPZURE5N/YMT6qhtwMwKCdT7Bzr44dEWlJKiMiIsfQY8x/sN/TkQzrc95+7o+m44hENZUREZFj8fo5MuwmAAr2PMmGz/YbDiQSvVRGRESOI2fUDVR6Uuhk7ePN+X8yHUckaqmMiIgcT0wcweF11x0ZVfZ3Xv94l+FAItFJZURE5ARSz7uBQ54kTrNLWPfc76g4XGs6kkjUURkRETkRfyIxY/4HgP8I/pOH5y81m0ckCqmMiIh8g5ihk6hKH0yCdYQhG37DG5vLTEcSiSoqIyIi38S2SfzOw4TwMNazmoVP/5WDNUHTqUSihsqIiEhjZPTBOavuQmi3BB7hgRfXGA4kEj1URkREGilm5B0cSehMllVO7rrf89aWfaYjiUQFlRERkcaKiSP2sgcBmOh5ldlPP0e1pmtETpnKiIhIU3QfTW2fK/BYLrccnsEDiz42nUgk4qmMiIg0UcyF06mNSaKfvQNWzWLV9nLTkUQimsqIiEhTJaQTc8G9APw/71M88HQRhwMhw6FEIpfKiIjIyRg0gWCnM2ln1XBd1Z/4/eKNphOJRCyVERGRk2HbeC/5Pxw7hm951rJz5dOsLf7cdCqRiKQyIiJystJ7YY+4BYBp3ieY9tRKjtRqukakqVRGREROxbk/I9S+K1lWOd8+8Dj3LdxgOpFIxFEZERE5FTFxeMb/AYCJnsWsf/tVXnx/t+FQIpFFZURE5FR1GwkDvo9tucz0Pcgfn32NrWUHTacSiRgqIyIizeGiB3Az+tHRquQhfsv/+9sKne4r0kgqIyIizcHXDusH8wi1S6eXvZObPv81dy9433QqkYigMiIi0lySc/B8fy6Ox0+h5z26f3A/T63eaTqVSNhTGRERaU6dBmNf/mcAfuJ9mXXPP8SGPZWGQ4mEN5UREZHmlv8d3HNvB+B/7L8y88m/UXWk1nAokfClMiIi0gKs8+8gcPql+KwQ0w7dx+/mvorruqZjiYQllRERkZZg2/i+M5Pqjv1JtQ7yw20/559vfGQ6lUhYUhkREWkpvnjaTXiKan8aPe1d5BRNZt2n+0ynEgk7KiMiIi0pKYv4CU8TsPycZ7/P5idv4cChgOlUImFFZUREpIVZOYMIXlp3hs13Qy+x4OH/ovxgjeFUIuFDZUREpBXED/wOZcN+DsCPDj3OG/83iT2f65LxIqAyIiLSatLG3UnZiGk4WFxW+zKfPHQ5O/aUmY4lYpzKiIhIa7Es0r41hc8vfIQAMZzrrOLgX8axaes208lEjFIZERFpZR2GfZfq7z1HhZVEPp8Q/7cL+PD9d03HEjFGZURExICUXudiX7eYEk8muZSS+9xlvPfmItOxRIxQGRERMSQxpzfJ
Ny1ni68X7a2D9Fn8Q9YsnG06lkirUxkRETEoLiWTzrcV8X67EfitWgavuo335t4LunS8tCEqIyIihvniEsi/7QXe6nAFAIM23s9Hf/ohtQfLDScTaR0qIyIiYcDj9VIweRZFnW/FcS3yy17i4AOD2P3mP7SXRKKeyoiISJiwbJvR197DinOeZBs5pLgHyF7yn3z68MWEyj81HU+kxaiMiIiEmXMLL6HdT1eyIHkCNa6XLvtXEHhoGPte+wM4IdPxRJqdyoiISBjKSE3m0lv/yOvnPcsatxdx7hE6rvgfyh48G2f3+6bjiTQrlRERkTBlWRbjRp1P5q2v82j7W6h040mrXI/7yPlUvngnBA6ZjijSLCzXDf8joyorK0lOTqaiooKkpCTTcUREWp3rujy7fA2JS+9krPUOANW+NKyCycQX/Bhi9W+jhJ/Gfn6rjIiIRJDi/Yf4599n8sPyh8mx9gNw2G5H2elXk3PBbXiSsw0nFPmSyoiISJRyHJdn3tnKrn89wfiDz9Dd3g1ALV42pY8j9Vv/j+wegwynFFEZERFpEz767HM+eH0uvbY9zhlsrB9f4x9O9dDJDD3nIuL8XoMJpS1TGRERaUOO1IZY/a9F+N+dweBDb2Fbdf+0f0Q3yruOZ+AFPyIpo6vhlNLWqIyIiLRRpds/Yu+r99Oz5CX81NaP70nIp/2QK4kb+G1o39lgQmkrVEZERNq4UNVe1i95Avfj+eQH19fvLQEIZg7C2+9y6HMZpHQxF1KimsqIiIgAdQe8Ll39ARuXzmFw9XKGWRsbFBOyBkLXc6DzWdD5TIhPNZZVoovKiIiINOA4Lks2lPK3JavIK3udi+x3GGZvwGM1/Bhw03phdS6AzgXQpQCSc8GyDKWWSKYyIiIix+S6LkUb9vLH1z9h92fFnGu/zxB7E0PtzfSwd31t/YP+DCrSBmNl9ichty+JnfKxUvLA1kW85cRURkRE5IRc12XZ5jIWfVhCcfkhdn5+iJqKvQxiI0PszQyzN9LX2kGM9fWb8x3BR6mvMwfanUYgtSeejN4kdMonO68nCfHxBl6NhKMWLSMzZszgd7/7HSUlJQwYMICHHnqIYcOGHXf9p59+mrvuuosdO3bQo0cPfvOb33DhhRc2+vlURkREWkcw5LCn4gg7Pz/EZ+WHKdm3H8+etaTsf48Oh7eTGyymm7Ubv1V7zO0d16LcSuZATDpH4jJxEnPwpXYiKSOPDtld8afmQkIGeP2t/MrEhBYrI/PmzWPChAnMnDmT4cOH8+CDD/L000+zadMm0tPTv7b+W2+9xbnnnsv06dO5+OKLmTNnDr/5zW9Yu3Yt+fn5zfpiRESkZdUEQ+zaf5CyzzZzeNfHULaJ+AOfkHpoGzmhncQRaNzvsWI55E2iJiaZoK89odgUiEvF0y4Fb0JHfO1SsGITsP2J2P522P4EvHGJePwJeGITsHwJmiaKAC1WRoYPH87QoUN5+OGHAXAch9zcXG6++WbuuOOOr61/1VVXUV1dzUsvvVQ/duaZZzJw4EBmzpzZrC9GREQMcl0qy0soKd7C5yU7qC77lNCBXcRU7yahppQ0Zz+ZVjl+K9gsT3fY9RGwYgjgI2jFUGv5CNo+gpYfx44h5PHj2j4cOwbX9uBaMbi2F9eOAc/Rr7YX7BgsjxdsD5btwbK9WLYH2+PF8njrx2yPB9v2YHs8eDw2lu3Bc3Q927bxeDxYlo1lW4Bdd9CvZeFigWXDF18t6+j3Fi585WfW1766fHHgsEXdt0cfN1j3K+NHv7pfHT/ewcf/tm1Kdjf8cQmn+p+lgcZ+fjfpGsGBQIA1a9YwderU+jHbtiksLGTlypXH3GblypVMmTKlwdjYsWNZsGDBcZ+npqaGmpqa+seVlZVNiSkiIiZYFkkdskjqkAWc0+BHrutSXh3go31VlO/fz6HKfQQq9xGq3k+ouhzr8Od4az4npraC2NoDxDnVxLlHiOMI7ThCO+sI8Ue//+Lsnzgr
cHRPTDW41C1Oa7/o6LHx4vn0GjLKyHM3qYzs27ePUChERkZGg/GMjAw2btx4zG1KSkqOuX5JSclxn2f69Oncc889TYkmIiJhzLIsOiT46ZDgh7yOwOmN2s51XUKOS9BxqQ05VIUcamsO49RUETxcTSBwmNqaL5dg4DChwBFCgcOEamtwag9DKAihQN1XJwhOLZZTW/c1FMRyguCGwA2CE8JynaNfv1gcLDcIrgOuC66DhYPlulhu6OjjunGbuu8tXOyj+zUsHOr2lTgNHn+x3vHHvpi4aPj433/+5X6PL3/+xfiXv+Ob18f2NOq/SUsIy7snTZ06tcHelMrKSnJzcw0mEhEREyzLwuux8HogNuboh2U7P9DeZKyo1N7gczepjHTs2BGPx0NpaWmD8dLSUjIzM4+5TWZmZpPWB/D7/fj9OtJaRESkLWjSocg+n4/BgwdTVFRUP+Y4DkVFRRQUFBxzm4KCggbrAyxZsuS464uIiEjb0uRpmilTpjBx4kSGDBnCsGHDePDBB6murmbSpEkATJgwgZycHKZPnw7ALbfcwnnnnccDDzzARRddxNy5c1m9ejWPPPJI874SERERiUhNLiNXXXUVZWVl3H333ZSUlDBw4EAWLVpUf5BqcXEx9lfO/T7rrLOYM2cOv/jFL7jzzjvp0aMHCxYsaPQ1RkRERCS66XLwIiIi0iIa+/mty9eJiIiIUSojIiIiYpTKiIiIiBilMiIiIiJGqYyIiIiIUSojIiIiYpTKiIiIiBilMiIiIiJGqYyIiIiIUU2+HLwJX1wktrKy0nASERERaawvPre/6WLvEVFGqqqqAMjNzTWcRERERJqqqqqK5OTk4/48Iu5N4zgOu3fvJjExEcuymu33VlZWkpuby86dO3XPm2+g96pp9H41nt6rxtN71Xh6rxqvJd8r13WpqqoiOzu7wU10/11E7BmxbZtOnTq12O9PSkrSH2sj6b1qGr1fjaf3qvH0XjWe3qvGa6n36kR7RL6gA1hFRETEKJURERERMapNlxG/38+0adPw+/2mo4Q9vVdNo/er8fReNZ7eq8bTe9V44fBeRcQBrCIiIhK92vSeERERETFPZURERESMUhkRERERo1RGRERExCiVka+45JJL6Ny5M7GxsWRlZXHNNdewe/du07HCzo4dO/jxj39M165diYuLo1u3bkybNo1AIGA6Wlj61a9+xVlnnUV8fDzt27c3HSeszJgxg7y8PGJjYxk+fDirVq0yHSksvfHGG4wfP57s7Gwsy2LBggWmI4Wt6dOnM3ToUBITE0lPT+eyyy5j06ZNpmOFpT//+c/079+//mJnBQUFvPLKK0ayqIx8xciRI3nqqafYtGkTzz77LFu3buWKK64wHSvsbNy4Ecdx+Mtf/sLHH3/MH/7wB2bOnMmdd95pOlpYCgQCXHnlldx4442mo4SVefPmMWXKFKZNm8batWsZMGAAY8eOZe/evaajhZ3q6moGDBjAjBkzTEcJe8uXL2fy5Mm8/fbbLFmyhNraWsaMGUN1dbXpaGGnU6dO/PrXv2bNmjWsXr2aUaNGcemll/Lxxx+3fhhXjuv55593LctyA4GA6Shh77e//a3btWtX0zHC2mOPPeYmJyebjhE2hg0b5k6ePLn+cSgUcrOzs93p06cbTBX+AHf+/PmmY0SMvXv3uoC7fPly01EiQkpKivvXv/611Z9Xe0aOo7y8nH/84x+cddZZxMTEmI4T9ioqKkhNTTUdQyJEIBBgzZo1FBYW1o/Ztk1hYSErV640mEyiTUVFBYD+ffoGoVCIuXPnUl1dTUFBQas/v8rIv7n99ttp164dHTp0oLi4mOeff950pLC3ZcsWHnroIX7yk5+YjiIRYt++fYRCITIyMhqMZ2RkUFJSYiiVRBvHcbj11lsZMWIE+fn5puOEpQ8//JCEhAT8fj833HAD8+fPp0+fPq2eI+rLyB133IFlWSdcNm7cWL/+f/3Xf/Hee++xePFiPB4P
EyZMwG0jF6lt6nsFsGvXLi644AKuvPJKrr/+ekPJW9/JvFci0romT57MRx99xNy5c01HCVunn34669at45133uHGG29k4sSJrF+/vtVzRP3l4MvKyti/f/8J1znttNPw+XxfG//ss8/Izc3lrbfeMrLbqrU19b3avXs3559/PmeeeSaPP/44th313bbeyfxdPf7449x6660cOHCghdOFv0AgQHx8PM888wyXXXZZ/fjEiRM5cOCA9kiegGVZzJ8/v8H7Jl9300038fzzz/PGG2/QtWtX03EiRmFhId26deMvf/lLqz6vt1WfzYC0tDTS0tJOalvHcQCoqalpzkhhqynv1a5duxg5ciSDBw/msccea1NFBE7t70rA5/MxePBgioqK6j9UHcehqKiIm266yWw4iWiu63LzzTczf/58li1bpiLSRI7jGPnMi/oy0ljvvPMO7777LmeffTYpKSls3bqVu+66i27durWJvSJNsWvXLs4//3y6dOnC/fffT1lZWf3PMjMzDSYLT8XFxZSXl1NcXEwoFGLdunUAdO/enYSEBLPhDJoyZQoTJ05kyJAhDBs2jAcffJDq6momTZpkOlrYOXjwIFu2bKl/vH37dtatW0dqaiqdO3c2mCz8TJ48mTlz5vD888+TmJhYfwxScnIycXFxhtOFl6lTpzJu3Dg6d+5MVVUVc+bMYdmyZbz66qutH6bVz98JUx988IE7cuRINzU11fX7/W5eXp57ww03uJ999pnpaGHnsccec4FjLvJ1EydOPOZ7tXTpUtPRjHvooYfczp07uz6fzx02bJj79ttvm44UlpYuXXrMv6GJEyeajhZ2jvdv02OPPWY6Wti59tpr3S5durg+n89NS0tzR48e7S5evNhIlqg/ZkRERETCW9ua6BcREZGwozIiIiIiRqmMiIiIiFEqIyIiImKUyoiIiIgYpTIiIiIiRqmMiIiIiFEqIyIiImKUyoiIiIgYpTIiIiIiRqmMiIiIiFEqIyLS6h555BGys7NxHKfB+KWXXsq1115rKJWImKIyIiKt7sorr2T//v0sXbq0fqy8vJxFixZx9dVXG0wmIiaojIhIq0tJSWHcuHHMmTOnfuyZZ56hY8eOjBw50mAyETFBZUREjLj66qt59tlnqampAeAf//gH3/ve97Bt/bMk0tZYruu6pkOISNtz5MgRMjIyeOyxxxg6dChdunRh9erVnHHGGaajiUgrUxkREWMmTZpEZWUlw4cP57HHHmPDhg2mI4mIAV7TAUSk7br66qu5+OKL+fjjj/nhD39oOo6IGKI9IyJijOM4dOrUiT179rB161ZOO+0005FExACVERERETFKh62LiIiIUSojIiIiYpTKiIiIiBilMiIiIiJGqYyIiIiIUSojIiIiYpTKiIiIiBilMiIiIiJGqYyIiIiIUSojIiIiYpTKiIiIiBj1/wGFrughmFgoCQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# x-axis: the drift values `v` written into column 0 of input_mat.\n",
    "# Computed once so both curves are guaranteed to share the same grid.\n",
    "v_grid = input_mat.at[:, 0].set(jnp.array(v))[:, 0]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(v_grid, choice_p_list, label=\"choice_p_list\")\n",
    "ax.plot(v_grid, np.exp(net_out), label=\"exp(net_out)\")\n",
    "ax.set(xlabel=\"v\", ylabel=\"choice probability\", title=f\"{MODEL}: choice probability vs. v\")\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[9.9949121e-01],\n",
       "       [9.9944276e-01],\n",
       "       [9.9937797e-01],\n",
       "       [9.9928951e-01],\n",
       "       [9.9916565e-01],\n",
       "       [9.9898762e-01],\n",
       "       [9.9872535e-01],\n",
       "       [9.9833125e-01],\n",
       "       [9.9773163e-01],\n",
       "       [9.9681586e-01],\n",
       "       [9.9542737e-01],\n",
       "       [9.9335217e-01],\n",
       "       [9.9030405e-01],\n",
       "       [9.8589939e-01],\n",
       "       [9.7961831e-01],\n",
       "       [9.7075099e-01],\n",
       "       [9.5833552e-01],\n",
       "       [9.4110185e-01],\n",
       "       [9.1743165e-01],\n",
       "       [8.8535011e-01],\n",
       "       [8.4261829e-01],\n",
       "       [7.8712934e-01],\n",
       "       [7.1787149e-01],\n",
       "       [6.3623816e-01],\n",
       "       [5.4648054e-01],\n",
       "       [4.5447949e-01],\n",
       "       [3.6594564e-01],\n",
       "       [2.8543937e-01],\n",
       "       [2.1616854e-01],\n",
       "       [1.5976143e-01],\n",
       "       [1.1601422e-01],\n",
       "       [8.3285108e-02],\n",
       "       [5.9331749e-02],\n",
       "       [4.1968789e-02],\n",
       "       [2.9401984e-02],\n",
       "       [2.0324167e-02],\n",
       "       [1.3843405e-02],\n",
       "       [9.3300734e-03],\n",
       "       [6.2859668e-03],\n",
       "       [4.2913300e-03],\n",
       "       [3.0071924e-03],\n",
       "       [2.1833957e-03],\n",
       "       [1.6502864e-03],\n",
       "       [1.2993332e-03],\n",
       "       [1.0632601e-03],\n",
       "       [9.0079551e-04],\n",
       "       [7.8651094e-04],\n",
       "       [7.0452545e-04],\n",
       "       [6.4473687e-04],\n",
       "       [6.0057559e-04]], dtype=float32)"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net_probs = np.exp(np.asarray(net_out))\n",
    "net_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ssms_dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
